
[Performance] Improve MiMo-Audio tokenizer decoding performance#2183

Merged
hsliuustc0106 merged 111 commits into vllm-project:main from qibaoyuan:tok_cg on May 11, 2026


@qibaoyuan
Contributor

@qibaoyuan qibaoyuan commented Mar 25, 2026

Purpose

This PR speeds up decoding in the MiMo-Audio audio tokenizer. The decoder is invoked frequently in asynchronous scenarios, so its latency is critical. Our approach uses CUDA Graphs to accelerate execution, adding fixed-shape (forward_fixed) paths that avoid variable-length operations.

Key changes include:

  • Attention.forward_fixed — Replaces flash_attn_varlen_func with F.scaled_dot_product_attention, operating on 3D tensors [B, L, D], thereby avoiding variable-length packing.
  • TransformerLayer.forward_fixed — Combines self_attn.forward_fixed with the feed-forward network (FFN).
  • CausalConvTranspose1d.forward_fixed — Applies transposed convolution directly on 3D tensors without using masked_select.
  • TransformerVocos.forward_fixed — Implements a mask-free forward path for the vocoder.
  • AudioDecoder.forward_fixed — Constructs the full decoder pipeline: dconv1 → transformer layers → dconv2 → vocoder.
  • MiMoAudioTokenizer.decode_fixed — Wraps the complete decoding process, including decode_vq, padding, and decoder.forward_fixed.
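The padding step in decode_fixed matters because a captured CUDA Graph replays with fixed tensor shapes, so variable-length input has to be padded to a length that was captured ahead of time. The following is a minimal, library-free sketch of that idea, not the PR's actual code: the bucket sizes, the pad id, and the helper names `pick_bucket` and `pad_codes` are all hypothetical.

```python
from bisect import bisect_left

# Hypothetical capture lengths (in code frames); the PR does not state the real ones.
CAPTURE_LENGTHS = [16, 32, 64, 128]
PAD_ID = 0  # hypothetical padding id


def pick_bucket(length, buckets=CAPTURE_LENGTHS):
    """Return the smallest captured length >= length, or None if none fits."""
    i = bisect_left(buckets, length)
    return buckets[i] if i < len(buckets) else None


def pad_codes(codes, buckets=CAPTURE_LENGTHS, pad_id=PAD_ID):
    """Pad a 1-D list of codes up to a captured length.

    Returns (padded_codes, valid_length). The caller would trim the decoder
    output back to the part that corresponds to valid_length. When no bucket
    fits, the codes are returned unpadded, which signals an eager fallback.
    """
    n = len(codes)
    bucket = pick_bucket(n, buckets)
    if bucket is None:
        return list(codes), n
    return list(codes) + [pad_id] * (bucket - n), n


if __name__ == "__main__":
    padded, valid = pad_codes([7, 8, 9])
    print(len(padded), valid)
```

Keeping the set of buckets small limits how many graphs must be captured, at the cost of some wasted compute on padding; where that trade-off lands for this model is not stated in the PR.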

Test Plan

export MIMO_AUDIO_TOKENIZER_PATH="XiaomiMiMo/MiMo-Audio-Tokenizer"

python3 -u end2end.py \
--stage-configs-path ./vllm_omni/model_executor/stage_configs/mimo_audio.yaml  \
--model  "XiaomiMiMo/MiMo-Audio-7B-Instruct" \
--query-type tts_sft_with_audio \
--audio_path ./examples/offline_inference/mimo_audio/beijing.mp3 \
--text "我还知道东北有杀猪菜,是把猪血肠、五花肉、酸菜等放在一块炖的,味道很浓郁。"

Test Result

Request ID: 0_3581f0d8-1ec1-4063-a223-72fa6a95b4a1, Text saved to ./output_audio/tts_sft_with_audio/0_3581f0d8-1ec1-4063-a223-72fa6a95b4a1.txt

Request ID: 0_3581f0d8-1ec1-4063-a223-72fa6a95b4a1, Audio saved to ./output_audio/tts_sft_with_audio/0_3581f0d8-1ec1-4063-a223-72fa6a95b4a1.wav



Essential Elements of an Effective PR Description Checklist
  • [x] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • [x] The test plan. Please provide the test scripts and test commands. If your code doesn't require additional test scripts, please state the reasons. For test file guidelines, please check the test style doc.
  • [x] The test results. Please paste a comparison of the results before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

qibaoyuan and others added 30 commits March 6, 2026 15:30
Signed-off-by: 齐保元 <qibaoyuan@xiaomi.com>
# Conflicts:
#	vllm_omni/model_executor/models/mimo_audio/mimo_audio_code2wav.py
@hsliuustc0106
Collaborator

Will this PR solve the accuracy issue?

@hsliuustc0106
Collaborator

Can you provide the end-to-end (e2e) improvement from this PR?

@qibaoyuan
Contributor Author

qibaoyuan commented May 7, 2026

Can you provide the end-to-end (e2e) improvement from this PR?

CUDA Graph vs Eager Execution Performance Comparison

Compared with eager execution, CUDA Graph execution cuts latency by roughly 7× to 11× across non-streaming and streaming modes.

Performance Table

| Mode | Eager (ms) | CUDA Graph (ms) | Speedup |
| --- | --- | --- | --- |
| Non-streaming | 60 | 5.5 | 10.93× |
| Streaming | 219.3 | 31.1 | 7.05× |

Key Observations

  • Non-streaming mode: CUDA Graph reduces latency from 60 ms to 5.5 ms, achieving a 10.93× speedup.
  • Streaming mode: Latency decreases from 219.3 ms to 31.1 ms, resulting in a 7.05× acceleration.
  • Overall, CUDA Graph cuts runtime overhead substantially in both modes, which matters most for latency-sensitive inference workloads.
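The speedup figures can be rechecked from the latencies reported in this comment. The short sketch below does only that arithmetic; the helper names `speedup` and `saved_fraction` are illustrative and not part of the PR.

```python
# Latencies in milliseconds, copied from the figures reported in this comment.
LATENCY_MS = {
    "Non-streaming": (60.0, 5.5),
    "Streaming": (219.3, 31.1),
}


def speedup(eager_ms, graph_ms):
    """Speedup factor of CUDA Graph execution over eager execution."""
    return eager_ms / graph_ms


def saved_fraction(eager_ms, graph_ms):
    """Fraction of the eager latency removed by CUDA Graph execution."""
    return 1.0 - graph_ms / eager_ms


if __name__ == "__main__":
    for mode, (eager, graph) in LATENCY_MS.items():
        print(f"{mode}: {speedup(eager, graph):.2f}x faster, "
              f"{saved_fraction(eager, graph):.1%} less latency")
```

Computed from the rounded values shown here, the non-streaming ratio comes out near 10.91×, slightly below the reported 10.93×; the reported figure was presumably derived from unrounded measurements. The streaming ratio matches the reported 7.05×.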

Collaborator

@linyueqian linyueqian left a comment

lgtm

@hsliuustc0106 hsliuustc0106 merged commit e108802 into vllm-project:main May 11, 2026
8 checks passed
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
…-project#2183)

Signed-off-by: 齐保元 <qibaoyuan@xiaomi.com>
Co-authored-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Galleons2029 added a commit to Galleons2029/vllm-omni-ljl that referenced this pull request May 18, 2026
…-project#2183)

Signed-off-by: 齐保元 <qibaoyuan@xiaomi.com>
Co-authored-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Signed-off-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
Labels

ready label to trigger buildkite CI

5 participants